[Model] Support Mamba2 (Codestral Mamba) #9292
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
Notes on current state:
@tlrmchlsmth thank you very much for your work! Kindly asking, do you have any updates since your last post?
hi @yury-tokpanov, sorry, I am focusing on other things right now -- namely support for tensor parallelism in V1. Currently this PR needs some fairly hard debugging. Hopefully I'll be able to resume working on this PR once I finish that work.
@tlrmchlsmth Thanks for the update! I work at Zyphra, and we are interested in incorporating our Zamba2 model into vLLM (#9382). I'm using your PR as a starting point, since we need mamba2 layers for that. If you're open to it, I'd be happy to help with finishing this PR.
Hi @yury-tokpanov, yes I would be very open to that! If you need any pointers/advice/help please feel free to reach out on the vllm developer slack (https://communityinviter.com/apps/vllm-dev/join-vllm-developers-slack)
@tlrmchlsmth @yury-tokpanov We also recently opened a PR to add a new model (Bamba) that also requires mamba v2 support. For continuous batching it works, but supporting chunked prefill can be quite challenging. cc: @ani300, @raghukiran1224, @njhill
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from bc9b5cf to 17923ad
Note that at this point, much of the implementation is now taken directly from #10909
@tlrmchlsmth @fabianlim thanks for all your work! I have our internal implementation of Zamba2 based on a previous version of this PR. I'm going to rebase it. Would you recommend using this branch or the one from the Bamba PR #10909?
assert not is_lora_enabled

self.config = config
self.padding_idx = config.pad_token_id
is this one used anywhere?
I'll look into cleaning this up. I see several models that have similar seemingly unused padding_idx variables.
This pull request has merge conflicts that must be resolved before it can be merged.
# For eager just take the scheduler_config if avail
self.max_batch_size = self.scheduler_config.max_num_seqs
else:
    self.max_batch_size = 128 + 2
what is the reason for setting max_batch_size at a much lower value than in mamba or jamba?
This was to avoid out-of-memory issues that I was seeing with Codestral Mamba.
Before landing I will check if we can simplify this code - I'm not sure when self.scheduler_config would be None.
I didn't do a very careful comparison, but given that mamba2's design is to have a higher headdim, the cache size of mamba2 should be much higher than that of mamba1. Hence we should not expect to allocate caches for the same number of seqs as in mamba1, I believe.
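For a concrete sense of scale, here is a rough back-of-the-envelope sketch (every hyperparameter below is an assumed, illustrative value, not read from either model's config) of the per-layer, per-sequence SSM state size:

```python
# Rough, illustrative sketch only -- the numbers are assumed example values.

def mamba1_ssm_state_elems(d_inner: int, d_state: int) -> int:
    # Mamba1 keeps a (d_inner, d_state) SSM state per layer and per sequence.
    return d_inner * d_state


def mamba2_ssm_state_elems(n_heads: int, head_dim: int, d_state: int) -> int:
    # Mamba2 keeps a (n_heads, head_dim, d_state) SSM state per layer and per sequence.
    return n_heads * head_dim * d_state


# Example: d_inner=8192, d_state=16 vs. n_heads=128, head_dim=64, d_state=128
print(mamba1_ssm_state_elems(8192, 16))       # 131072 elements per layer/seq
print(mamba2_ssm_state_elems(128, 64, 128))   # 1048576 elements, ~8x larger
```

Under the same memory budget, that means caches for far fewer sequences fit at once.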
@yury-tokpanov I'd recommend using the mixer2 implementation from this PR -- I think there are only one or two small changes from the Bamba PR but notably I did fix one correctness issue with a
Heads up that Mamba tests currently fail on main: #12465. It would be great if you could solve the issue in this PR as well!
I'm looking into it. I see the following error in the logs, but so far am unable to reproduce this issue on an H100.
Edit: fixed error message
The other update I have is that I am trying to reproduce the humaneval results reported here, using https://github.com/neuralmagic/evalplus, but no luck so far:
repro steps:
GSM8k results:
I am unable to reproduce eval results for our Zamba2 model with lm_eval both for some loglikelihood tasks (winogrande, arc tasks) and generation tasks (like gsm8k), while some loglikelihood tasks are fine for some reason (mmlu, hellaswag). When I dig deeper and compare the outputs layer by layer with our HF implementation, I see there is a small discrepancy in mamba2 layers starting from the first one, and it accumulates over the whole network. Final logits are within 1% of each other between the two implementations. Going to check what's going on. @fabianlim were you able to reproduce bamba results in lm_eval with your vllm implementation?
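For reference, a generic sketch of this kind of layer-by-layer comparison (the helper and layer-name filter are hypothetical, not project code): capture each mixer's output in both implementations with forward hooks and report where the activations start to diverge.

```python
import torch


def capture_layer_outputs(model: torch.nn.Module, name_filter) -> dict:
    """Register forward hooks and collect each matching module's output."""
    outputs = {}

    def make_hook(name):
        def hook(_module, _inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            outputs[name] = out.detach().float().cpu()
        return hook

    for name, module in model.named_modules():
        if name_filter(name):
            module.register_forward_hook(make_hook(name))
    return outputs


# Hypothetical usage: run the same prompt through both implementations,
# then print the max absolute difference per captured layer.
# hf_outs = capture_layer_outputs(hf_model, lambda n: "mamba" in n.lower())
# vllm_outs = capture_layer_outputs(vllm_model, lambda n: "mamba" in n.lower())
# ...forward pass with identical inputs in both models...
# for name in sorted(hf_outs):
#     print(name, (hf_outs[name] - vllm_outs[name]).abs().max().item())
```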
@yury-tokpanov no, I have not yet tried reproducing the benchmarks on vLLM. I'll have to try it myself.
@yury-tokpanov with @tlrmchlsmth's help we have also verified the HF:
The computation of gated RMS norm depends on the number of Mamba2 groups: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/layernorm_gated.py#L32 . Our 7B model has 2 groups, so it definitely affects it. I'm still chasing other discrepancies. Seems like Codestral 7B uses 8 groups, so it'll definitely be an issue for that model as well: https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1/blob/main/config.json
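For concreteness, a minimal PyTorch sketch of a group-aware gated RMS norm in the spirit of the linked layernorm_gated kernel (the class name and the gate-before-norm convention are assumptions here, not vLLM's actual Mixer2RMSNormGated):

```python
import torch
import torch.nn as nn


class GroupedRMSNormGated(nn.Module):
    """Gated RMS norm whose statistic is computed per channel group,
    mirroring ngroups > 1 in the reference Triton kernel (sketch only)."""

    def __init__(self, hidden_size: int, n_groups: int, eps: float = 1e-5):
        super().__init__()
        assert hidden_size % n_groups == 0
        self.n_groups = n_groups
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Apply the SiLU-activated gate first (norm-before-gate disabled),
        # then normalize the gated activations.
        x = x * nn.functional.silu(gate)
        # The RMS statistic is taken within each group of channels rather than
        # over the full hidden dimension -- this is what changes for n_groups > 1.
        *batch, d = x.shape
        grouped = x.view(*batch, self.n_groups, d // self.n_groups)
        var = grouped.pow(2).mean(dim=-1, keepdim=True)
        grouped = grouped * torch.rsqrt(var + self.eps)
        return grouped.view(*batch, d) * self.weight
```

With n_groups=1 this reduces to the usual gated RMS norm over the full hidden dimension.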
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
Need to implement case of n_groups > 1: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/layernorm_gated.py#L31
Thanks for the pointer! I'll take a look tomorrow.
After fixing gated rms norm, I was able to match gsm8k results for our 7B model. I still see some task numbers being lower for some reason, so I'm going to investigate further. @fabianlim did you compare other evals? The ones that I see differing and that are on your evals table are ARC-C and Winogrande. Btw, we just merged the Zamba2 implementation into transformers, so once I'm done with the vLLM implementation, I'll create a PR here: I need to fix/check correctness and evals, add PP support, and then clean up. Our architecture uses shared transformer layers with LoRAs in them, so I'll need to think a little bit about how to adapt it to the vLLM style; it seems like there is a big refactoring going on, which already caught me with the kv cache.
@yury-tokpanov I can reproduce the HF
VLLM
Also for
@yury-tokpanov could you share what you did to fix gated rms norm? I don't see n_groups being handled in zamba here https://github.com/huggingface/transformers/blob/main/src/transformers/models/zamba2/modeling_zamba2.py#L64-L79
We have a new PR in transformers fixing this issue: huggingface/transformers#35943. I did the same thing for my vLLM implementation, but I ignored TP for now as I've been testing our models with TP=1 so far. Also, I tested with the original triton implementation from the mamba2 repo to make sure I'm getting the same eval results:
Updated to handle groups in Mixer2RMSNormGated -- GSM8K results are much improved:
Will re-run humaneval as well.
humaneval results looking much better as well:
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
I'm curious, how did you decide which layers should be quantizable and which not? Did you run experiments?
There were a few reasons, including the following: 1) llm-compressor does not quantize conv1d layers, and 2) the conv1d kernels do not support fp8. In our blog we have a few numbers when using this scheme. cc: @nwang-ibm
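For illustration, a hedged sketch of that kind of selective FP8 scheme with llm-compressor (the import paths and arguments follow its published FP8 examples as I recall them and should be treated as assumptions): only Linear modules are targeted, so the conv1d weights stay in their original precision.

```python
# Hedged sketch: quantize only the Linear projections to FP8 dynamic and leave
# conv1d (and lm_head) untouched. API details below are assumed from
# llm-compressor's FP8 examples, not the exact recipe used for this PR.
from transformers import AutoModelForCausalLM

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "mistralai/Mamba-Codestral-7B-v0.1"  # assuming a transformers-loadable checkpoint
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

recipe = QuantizationModifier(
    targets="Linear",      # only nn.Linear modules are quantized...
    scheme="FP8_DYNAMIC",
    ignore=["lm_head"],    # ...and the output head is explicitly skipped
)

oneshot(model=model, recipe=recipe)
model.save_pretrained("Mamba-Codestral-7B-FP8-Dynamic", save_compressed=True)
```

In the HF Mamba2 implementation the convolution is an nn.Conv1d rather than an nn.Linear, so a Linear-only scheme leaves it untouched, which lines up with point 1) above.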
I rebased using the latest version of this PR, and now I'm getting this error from
I see there was this FA3 revert commit for ViT MHA, reporting the same error: #12445. A bit weird, since I'm not using FA3, but something is broken nonetheless.
See fabianlim@5d240a0
Add support for Mamba2. Not thoroughly tested yet, but Codestral Mamba has legible outputs.
Todo:
- mamba_chunk_scan_combined kernel to avoid the dependency on mamba_ssm
Closes #6479